I tried running inference with the 2B model from https://github.com/google-deepmind/gemma on my M2 MacBook Pro, but it segfaults during sampling: https://pastebin.com/KECyz60T
Note: out of the box, it tries to load the weights as bfloat16, which fails. To work around this, I patched line 30 of gemma/params.py to explicitly cast every parameter to float32:
```python
# gemma/params.py, line 30: cast every leaf of the params pytree to float32
param_state = jax.tree_util.tree_map(lambda p: jnp.array(p, jnp.float32), params)
```
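To double-check that a cast like this reaches every leaf before sampling, here is a minimal, self-contained sketch that applies the same `jax.tree_util.tree_map` cast to a toy parameter tree and prints each leaf's dtype. The nested dict and its key names (`embedder`, `input_embedding`, `final_norm`, `scale`) are hypothetical stand-ins, not the real Gemma checkpoint layout, and `cast_params_to_float32` is a helper name introduced only for illustration.

```python
import jax
import jax.numpy as jnp


def cast_params_to_float32(params):
    """Hypothetical helper: return a new pytree with every leaf cast to float32."""
    return jax.tree_util.tree_map(lambda p: jnp.asarray(p, dtype=jnp.float32), params)


# Hypothetical toy parameter tree standing in for a bfloat16 checkpoint.
toy_params = {
    "embedder": {"input_embedding": jnp.ones((4, 8), dtype=jnp.bfloat16)},
    "final_norm": {"scale": jnp.zeros((8,), dtype=jnp.bfloat16)},
}

params32 = cast_params_to_float32(toy_params)

# Print each leaf to confirm no bfloat16 arrays survive and shapes are unchanged.
for leaf in jax.tree_util.tree_leaves(params32):
    print(leaf.dtype, leaf.shape)
```

Because `tree_map` preserves the tree structure, the result can be used wherever the original params were; keep in mind that float32 weights take twice the memory of bfloat16 ones.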